/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.solr.client.solrj.io.graph;

import static org.apache.solr.common.params.CommonParams.SORT;

import java.io.IOException;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Deque;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.stream.Collectors;
import org.apache.solr.client.solrj.io.Tuple;
import org.apache.solr.client.solrj.io.comp.StreamComparator;
import org.apache.solr.client.solrj.io.eq.FieldEqualitor;
import org.apache.solr.client.solrj.io.eq.MultipleFieldEqualitor;
import org.apache.solr.client.solrj.io.stream.CloudSolrStream;
import org.apache.solr.client.solrj.io.stream.StreamContext;
import org.apache.solr.client.solrj.io.stream.TupleStream;
import org.apache.solr.client.solrj.io.stream.UniqueStream;
import org.apache.solr.client.solrj.io.stream.expr.Explanation;
import org.apache.solr.client.solrj.io.stream.expr.Explanation.ExpressionType;
import org.apache.solr.client.solrj.io.stream.expr.Expressible;
import org.apache.solr.client.solrj.io.stream.expr.StreamExplanation;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpressionNamedParameter;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpressionParameter;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpressionValue;
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
import org.apache.solr.common.params.ModifiableSolrParams;
import org.apache.solr.common.params.SolrParams;
import org.apache.solr.common.util.ExecutorUtil;
import org.apache.solr.common.util.SolrNamedThreadFactory;

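/**
 * Tuple stream that runs a breadth-first search over a Solr collection, following edges from one
 * field to another, and emits each shortest path found between the start node and the target node
 * as a tuple with a single {@code path} field.
 *
 * @since 6.1.0
 */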
public class ShortestPathStream extends TupleStream implements Expressible {

  private static final long serialVersionUID = 1;

  // Start node of the search; seeds the first level of the traversal in open().
  private String fromNode;
  // Target node; the traversal stops after the first level that reaches it.
  private String toNode;
  // Field that is queried with the nodes of the current level (source side of an edge).
  private String fromField;
  // Field whose value supplies the node on the other end of an edge (destination side).
  private String toField;
  // Maximum number of nodes placed into a single join query.
  private int joinBatchSize;
  // Maximum number of levels the traversal expands before giving up.
  private int maxDepth;
  // Passed, together with collection, to every CloudSolrStream created for a join query.
  private String zkHost;
  private String collection;
  // Paths discovered by open(); read() hands them out one at a time.
  private final Deque<Tuple> shortestPaths = new ArrayDeque<>();
  // Set by read() once it has returned at least one path; reset by close().
  private boolean found;
  // Context handed to every join stream before it is opened.
  private StreamContext streamContext;
  // Size of the fixed thread pool that runs the join queries.
  private int threads;
  // Caller-supplied parameters copied into every join query.
  private SolrParams queryParams;

  /**
   * Creates a stream from explicit settings; all arguments are stored as-is without validation.
   *
   * @param zkHost passed to the underlying {@code CloudSolrStream} instances
   * @param collection collection the join queries are sent to
   * @param fromNode node the search starts from
   * @param toNode node the search is looking for
   * @param fromField field queried with the nodes of the current level
   * @param toField field that supplies the node at the other end of an edge
   * @param queryParams parameters copied into every join query
   * @param joinBatchSize maximum number of nodes per join query
   * @param threads size of the thread pool that runs the join queries
   * @param maxDepth maximum number of levels to expand
   */
  public ShortestPathStream(
      String zkHost,
      String collection,
      String fromNode,
      String toNode,
      String fromField,
      String toField,
      SolrParams queryParams,
      int joinBatchSize,
      int threads,
      int maxDepth) {

    // All state is assigned in one place so both constructors stay in sync.
    init(
        zkHost,
        collection,
        fromNode,
        toNode,
        fromField,
        toField,
        queryParams,
        joinBatchSize,
        threads,
        maxDepth);
  }

  /**
   * Creates a stream from a parsed streaming expression.
   *
   * <p>The first operand is the collection name. The named operands {@code from}, {@code to},
   * {@code edge} (two field names separated by {@code =}) and {@code maxDepth} are required.
   * {@code threads} defaults to 6 and {@code partitionSize} to 250. {@code zkHost} is optional;
   * when absent it is looked up in the factory, first for the collection and then as the default.
   * Every other named operand is passed through as a query parameter.
   *
   * @param expression the expression to read the settings from
   * @param factory factory used to read operands and to resolve the zkHost
   * @throws IOException if a required operand is missing, an operand is not a plain value, a
   *     numeric operand is not a valid integer, or no zkHost can be determined
   */
  public ShortestPathStream(StreamExpression expression, StreamFactory factory) throws IOException {

    String collectionName = factory.getValueOperand(expression, 0);
    List<StreamExpressionNamedParameter> namedParams = factory.getNamedOperands(expression);
    StreamExpressionNamedParameter zkHostExpression = factory.getNamedOperand(expression, "zkHost");

    // Collection Name
    if (null == collectionName) {
      throw new IOException(
          String.format(
              Locale.ROOT,
              "invalid expression %s - collectionName expected as first operand",
              expression));
    }

    String fromNode = requireNamedValue(expression, factory, "from");
    String toNode = requireNamedValue(expression, factory, "to");

    // The edge is written as "fromField=toField".
    String edge = requireNamedValue(expression, factory, "edge");
    String[] fields = edge.split("=");
    if (fields.length != 2) {
      throw new IOException(
          String.format(
              Locale.ROOT,
              "invalid expression %s - edge param must contain two fields separated by an =",
              expression));
    }
    String fromField = fields[0].trim();
    String toField = fields[1].trim();

    StreamExpressionNamedParameter threadsExpression =
        factory.getNamedOperand(expression, "threads");
    int threads =
        threadsExpression == null
            ? 6
            : parseIntParam(
                expression, "threads", namedValue(expression, threadsExpression, "threads"));

    StreamExpressionNamedParameter partitionExpression =
        factory.getNamedOperand(expression, "partitionSize");
    int partitionSize =
        partitionExpression == null
            ? 250
            : parseIntParam(
                expression,
                "partitionSize",
                namedValue(expression, partitionExpression, "partitionSize"));

    int maxDepth =
        parseIntParam(expression, "maxDepth", requireNamedValue(expression, factory, "maxDepth"));

    // Everything that is not one of this stream's own settings is forwarded as a query parameter.
    ModifiableSolrParams params = new ModifiableSolrParams();
    for (StreamExpressionNamedParameter namedParam : namedParams) {
      if (!namedParam.getName().equals("zkHost")
          && !namedParam.getName().equals("to")
          && !namedParam.getName().equals("from")
          && !namedParam.getName().equals("edge")
          && !namedParam.getName().equals("maxDepth")
          && !namedParam.getName().equals("threads")
          && !namedParam.getName().equals("partitionSize")) {
        params.set(namedParam.getName(), namedParam.getParameter().toString().trim());
      }
    }

    // zkHost, optional - if not provided then will look into factory list to get
    String zkHost = null;
    if (null == zkHostExpression) {
      zkHost = factory.getCollectionZkHost(collectionName);
      if (zkHost == null) {
        zkHost = factory.getDefaultZkHost();
      }
    } else if (zkHostExpression.getParameter() instanceof StreamExpressionValue) {
      zkHost = ((StreamExpressionValue) zkHostExpression.getParameter()).getValue();
    }

    if (null == zkHost) {
      throw new IOException(
          String.format(
              Locale.ROOT,
              "invalid expression %s - zkHost not found for collection '%s'",
              expression,
              collectionName));
    }

    // We've got all the required items
    init(
        zkHost,
        collectionName,
        fromNode,
        toNode,
        fromField,
        toField,
        params,
        partitionSize,
        threads,
        maxDepth);
  }

  /** Returns the value of a required named operand, failing if it is absent. */
  private static String requireNamedValue(
      StreamExpression expression, StreamFactory factory, String name) throws IOException {
    StreamExpressionNamedParameter namedParam = factory.getNamedOperand(expression, name);
    if (namedParam == null) {
      throw new IOException(
          String.format(
              Locale.ROOT, "invalid expression %s - %s param is required", expression, name));
    }
    return namedValue(expression, namedParam, name);
  }

  /** Returns the plain value of a named operand, failing if it holds anything but a value. */
  private static String namedValue(
      StreamExpression expression, StreamExpressionNamedParameter namedParam, String name)
      throws IOException {
    // Checked instead of cast blindly so a nested expression yields a readable error.
    if (!(namedParam.getParameter() instanceof StreamExpressionValue)) {
      throw new IOException(
          String.format(
              Locale.ROOT, "invalid expression %s - %s param must be a value", expression, name));
    }
    return ((StreamExpressionValue) namedParam.getParameter()).getValue();
  }

  /** Parses an integer operand, reporting bad input as an invalid-expression error. */
  private static int parseIntParam(StreamExpression expression, String name, String value)
      throws IOException {
    try {
      return Integer.parseInt(value);
    } catch (NumberFormatException e) {
      throw new IOException(
          String.format(
              Locale.ROOT,
              "invalid expression %s - %s '%s' is not a valid integer",
              expression,
              name,
              value),
          e);
    }
  }

  /**
   * Stores every setting of the stream; shared by both constructors. No validation is performed.
   */
  private void init(
      String zkHost,
      String collection,
      String fromNode,
      String toNode,
      String fromField,
      String toField,
      SolrParams queryParams,
      int joinBatchSize,
      int threads,
      int maxDepth) {
    // Where to search.
    this.collection = collection;
    this.zkHost = zkHost;
    this.queryParams = queryParams;

    // What to search for.
    this.fromNode = fromNode;
    this.toNode = toNode;
    this.fromField = fromField;
    this.toField = toField;

    // How to search.
    this.maxDepth = maxDepth;
    this.threads = threads;
    this.joinBatchSize = joinBatchSize;
  }

  /**
   * Renders this stream as an expression: the collection, then every query parameter, then the
   * stream's own settings ({@code zkHost}, {@code maxDepth}, {@code threads}, {@code
   * partitionSize}, {@code from}, {@code to}, {@code edge}).
   */
  @Override
  public StreamExpressionParameter toExpression(StreamFactory factory) throws IOException {

    StreamExpression result = new StreamExpression(factory.getFunctionName(this.getClass()));

    // collection
    result.addParameter(collection);

    // query parameters; multiple values of one parameter are joined with a comma
    Map<String, String[]> paramMap = new ModifiableSolrParams(queryParams).getMap();
    for (Map.Entry<String, String[]> entry : paramMap.entrySet()) {
      // SOLR-8409: This is a special case where the params contain a " character
      // Do note that in any other BASE streams with parameters where a " might come into play
      // that this same replacement needs to take place.
      String escaped = String.join(",", entry.getValue()).replace("\"", "\\\"");
      result.addParameter(new StreamExpressionNamedParameter(entry.getKey(), escaped));
    }

    // the stream's own settings, in their fixed output order
    String[][] ownSettings = {
      {"zkHost", zkHost},
      {"maxDepth", Integer.toString(maxDepth)},
      {"threads", Integer.toString(threads)},
      {"partitionSize", Integer.toString(joinBatchSize)},
      {"from", fromNode},
      {"to", toNode},
      {"edge", fromField + "=" + toField}
    };
    for (String[] setting : ownSettings) {
      result.addParameter(new StreamExpressionNamedParameter(setting[0], setting[1]));
    }
    return result;
  }

  /**
   * Describes this stream as a graph source with a single datastore child whose expression lists
   * the query parameters as comma-separated {@code name=[values]} pairs.
   */
  @Override
  public Explanation toExplanation(StreamFactory factory) throws IOException {

    StreamExplanation root = new StreamExplanation(getStreamNodeId().toString());
    root.setFunctionName(factory.getFunctionName(this.getClass()));
    root.setImplementingClass(this.getClass().getName());
    root.setExpressionType(ExpressionType.GRAPH_SOURCE);
    root.setExpression(toExpression(factory).toString());

    // child is a datastore so add it at this point
    StreamExplanation datastore = new StreamExplanation(getStreamNodeId() + "-datastore");
    datastore.setFunctionName("solr (graph)");
    datastore.setImplementingClass("Solr/Lucene");
    datastore.setExpressionType(ExpressionType.DATASTORE);

    StringBuilder paramText = new StringBuilder();
    boolean first = true;
    for (Map.Entry<String, String[]> entry :
        new ModifiableSolrParams(queryParams).getMap().entrySet()) {
      if (!first) {
        paramText.append(",");
      }
      first = false;
      paramText.append(
          String.format(
              Locale.ROOT, "%s=%s", entry.getKey(), Arrays.toString(entry.getValue())));
    }
    datastore.setExpression(paramText.toString());

    root.addChild(datastore);
    return root;
  }

  /** Stores the context; it is handed to every join stream created while the search runs. */
  @Override
  public void setStreamContext(StreamContext context) {
    this.streamContext = context;
  }

  /** Returns a new, empty list: this stream wraps no other streams. */
  @Override
  public List<TupleStream> children() {
    return new ArrayList<>();
  }

  /**
   * Runs the whole search and buffers the resulting paths for {@link #read()}.
   *
   * <p>Starting from {@code fromNode}, each iteration queries the edges leaving the nodes of the
   * current level (in batches of {@code joinBatchSize}, executed on a pool of {@code threads}
   * threads) and records, for every node reached, the parents that reached it. Nodes already seen
   * on an earlier level are skipped. The loop ends after the first level that reaches {@code
   * toNode}, or after {@code maxDepth} levels. If the target was reached, the paths are rebuilt by
   * walking the recorded levels backwards.
   *
   * @throws IOException declared by the overridden method
   * @throws RuntimeException wrapping any failure while submitting or running a join query
   */
  @Override
  public void open() throws IOException {

    // Drop anything left over from a previous open() so results of two searches never mix.
    shortestPaths.clear();

    // allVisited.get(i) maps every node first reached on level i to the parents that reached it.
    List<Map<String, List<String>>> allVisited = new ArrayList<>();
    Map<String, List<String>> visited = new HashMap<>();
    visited.put(this.fromNode, null); // the start node has no parents
    allVisited.add(visited);

    boolean targetReached = false;
    int depth = 0;

    // Created before the try block so the finally clause never sees a null pool.
    ExecutorService threadPool =
        ExecutorUtil.newMDCAwareFixedThreadPool(
            threads, new SolrNamedThreadFactory("ShortestPathStream"));

    try {
      // Breadth first search
      while (!targetReached && depth < maxDepth) {
        List<Future<List<Edge>>> futures = submitJoins(threadPool, visited.keySet());
        Map<String, List<String>> nextVisited = new HashMap<>();

        try {
          // Process the batches in submission order
          for (Future<List<Edge>> future : futures) {
            for (Edge edge : future.get()) {
              boolean isTarget = toNode.equals(edge.to);
              if (isTarget) {
                targetReached = true;
              }
              // Edges into the target are always recorded; others only for unseen nodes.
              if (isTarget || !cycle(edge.to, allVisited)) {
                nextVisited.computeIfAbsent(edge.to, k -> new ArrayList<>()).add(edge.from);
              }
            }
          }
        } catch (InterruptedException e) {
          // Keep the interrupt visible to callers further up the stack.
          Thread.currentThread().interrupt();
          throw new RuntimeException(e);
        } catch (Exception e) {
          throw new RuntimeException(e);
        }

        allVisited.add(nextVisited);
        visited = nextVisited;
        ++depth;
      }
    } finally {
      threadPool.shutdown();
    }

    if (targetReached) {
      // Every target edge ends at toNode, so a single backwards walk yields all paths.
      Set<String> emitted = new HashSet<>();
      for (Deque<String> path : buildPaths(allVisited)) {
        if (emitted.add(path.toString())) {
          shortestPaths.add(new Tuple("path", path));
        }
      }
    }
  }

  /** Splits the nodes into batches of at most joinBatchSize and submits one join per batch. */
  private List<Future<List<Edge>>> submitJoins(ExecutorService threadPool, Set<String> nodes) {
    List<Future<List<Edge>>> futures = new ArrayList<>();
    List<String> batch = new ArrayList<>();
    Iterator<String> it = nodes.iterator();
    while (it.hasNext()) {
      batch.add(it.next());
      // Flush when the batch is full or when this was the last node.
      if (batch.size() == joinBatchSize || !it.hasNext()) {
        try {
          futures.add(threadPool.submit(new JoinRunner(batch)));
        } catch (Exception e) {
          throw new RuntimeException(e);
        }
        batch = new ArrayList<>();
      }
    }
    return futures;
  }

  /** Rebuilds all paths ending at toNode by prepending parents level by level, deepest first. */
  private List<Deque<String>> buildPaths(List<Map<String, List<String>>> allVisited) {
    List<Deque<String>> paths = new ArrayList<>();
    Deque<String> seed = new ArrayDeque<>();
    seed.addFirst(toNode);
    paths.add(seed);

    // Walk back up the tree and collect the parent nodes.
    for (int i = allVisited.size() - 1; i >= 0; --i) {
      Map<String, List<String>> level = allVisited.get(i);
      List<Deque<String>> extendedPaths = new ArrayList<>();
      boolean anyExtended = false;
      for (Deque<String> path : paths) {
        List<String> parents = level.get(path.peekFirst());
        if (parents != null) {
          anyExtended = true;
          // One new path per parent: the path forks wherever a node was reached several ways.
          for (String parent : parents) {
            Deque<String> longer = new ArrayDeque<>(path);
            longer.addFirst(parent);
            extendedPaths.add(longer);
          }
        }
      }
      // On a level without parents (the start node) the paths are carried over unchanged.
      if (anyExtended) {
        paths = extendedPaths;
      }
    }
    return paths;
  }

  /**
   * Task that fetches, for one batch of nodes, every distinct edge leaving those nodes. It queries
   * {@code fromField} for the batch through the {@code /export} handler and returns one {@link
   * Edge} per unique ({@code toField}, {@code fromField}) pair.
   */
  private class JoinRunner implements Callable<List<Edge>> {

    private final List<String> nodes;

    public JoinRunner(List<String> nodes) {
      this.nodes = nodes;
    }

    /**
     * Runs the join query and collects its edges.
     *
     * @return the edges found for this batch, in the order the stream delivered them
     * @throws RuntimeException wrapping any failure while creating, reading or closing the stream
     */
    @Override
    public List<Edge> call() {

      ModifiableSolrParams joinParams = new ModifiableSolrParams(queryParams);
      joinParams.set("fl", fromField + "," + toField);
      joinParams.set("qt", "/export");
      // Sorted on both fields so the UniqueStream below sees equal pairs next to each other.
      joinParams.set(SORT, toField + " asc," + fromField + " asc");

      // NOTE(review): node values are placed into the query unescaped; values containing query
      // syntax characters or whitespace will change the query - confirm whether escaping is needed.
      joinParams.set("q", fromField + ":(" + String.join(" ", nodes).trim() + ")");

      // Local to the call so a second invocation never sees edges from the first.
      List<Edge> edges = new ArrayList<>();
      TupleStream stream = null;
      try {
        stream =
            new UniqueStream(
                new CloudSolrStream(zkHost, collection, joinParams),
                new MultipleFieldEqualitor(
                    new FieldEqualitor(toField), new FieldEqualitor(fromField)));
        stream.setStreamContext(streamContext);
        stream.open();
        while (true) {
          Tuple tuple = stream.read();
          if (tuple.EOF) {
            break;
          }
          edges.add(new Edge(tuple.getString(fromField), tuple.getString(toField)));
        }
      } catch (Exception e) {
        throw new RuntimeException(e);
      } finally {
        // stream is still null when its construction failed; there is nothing to close then,
        // and dereferencing it would hide the real error behind a NullPointerException.
        if (stream != null) {
          try {
            stream.close();
          } catch (Exception ce) {
            throw new RuntimeException(ce);
          }
        }
      }
      return edges;
    }
  }

  /** Directed link between two nodes as returned by a join query; never modified once built. */
  private static class Edge {

    private final String from;
    private final String to;

    public Edge(String from, String to) {
      this.to = to;
      this.from = from;
    }
  }

  /**
   * Tells whether the node was already reached on any earlier level, i.e. whether following an
   * edge to it would revisit a node.
   */
  private boolean cycle(String node, List<Map<String, List<String>>> allVisited) {
    return allVisited.stream().anyMatch(level -> level.containsKey(node));
  }

  /**
   * Resets the flag that records whether a path has been returned. Paths that were buffered but
   * not yet read are left untouched.
   */
  @Override
  public void close() throws IOException {
    this.found = false;
  }

  /**
   * Returns the next buffered path, or an EOF tuple once none are left. If no path was ever
   * returned, the EOF tuple additionally carries a {@code sorry} field explaining that nothing
   * was found.
   */
  @Override
  public Tuple read() throws IOException {
    if (!shortestPaths.isEmpty()) {
      found = true;
      return shortestPaths.removeFirst();
    }

    Tuple eof = Tuple.EOF();
    if (!found) {
      eof.put("sorry", "No path found");
    }
    return eof;
  }

  /** Always reports a cost of 0. */
  @Override
  public int getCost() {
    return 0;
  }

  /** Returns {@code null}: no sort comparator is provided for the emitted paths. */
  @Override
  public StreamComparator getStreamSort() {
    return null;
  }
}
